{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install gymnasium\n",
    "!pip install 'gymnasium[mujoco]'\n",
    "!pip install matplotlib\n",
    "!pip3 install torch torchvision torchaudio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "from IPython import display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/rocketsky/miniconda3/envs/mac-rl/lib/python3.10/site-packages/gymnasium/envs/registration.py:517: DeprecationWarning: \u001b[33mWARN: The environment Ant-v4 is out of date. You should consider upgrading to version `v5`.\u001b[0m\n",
      "  logger.deprecation(\n"
     ]
    }
   ],
   "source": [
    "import gymnasium as gym\n",
    "import math\n",
    "import random\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import namedtuple, deque\n",
    "from itertools import count\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "env = gym.make(\"Ant-v4\",render_mode=\"human\")\n",
    "\n",
    "# set up matplotlib\n",
    "is_ipython = 'inline' in matplotlib.get_backend()\n",
    "if is_ipython:\n",
    "    from IPython import display\n",
    "\n",
    "plt.ion()\n",
    "\n",
    "# if gpu is to be used\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "Transition = namedtuple('Transition',\n",
    "                        ('state', 'action', 'next_state', 'reward'))\n",
    "\n",
    "\n",
    "class ReplayMemory(object):\n",
    "\n",
    "    def __init__(self, capacity):\n",
    "        self.memory = deque([], maxlen=capacity)\n",
    "\n",
    "    def push(self, *args):\n",
    "        \"\"\"Save a transition\"\"\"\n",
    "        self.memory.append(Transition(*args))\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        return random.sample(self.memory, batch_size)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DQN(nn.Module):\n",
    "\n",
    "    def __init__(self, n_observations, n_actions):\n",
    "        super(DQN, self).__init__()\n",
    "        self.layer1 = nn.Linear(n_observations, 128)\n",
    "        self.layer2 = nn.Linear(128, 128)\n",
    "        self.layer3 = nn.Linear(128, n_actions)\n",
    "\n",
    "    # Called with either one element to determine next action, or a batch\n",
    "    # during optimization. Returns tensor([[left0exp,right0exp]...]).\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.layer1(x))\n",
    "        x = F.relu(self.layer2(x))\n",
    "        return self.layer3(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-18 01:22:57.328 python[79978:7684974] +[IMKClient subclass]: chose IMKClient_Modern\n",
      "2024-11-18 01:22:57.328 python[79978:7684974] +[IMKInputSession subclass]: chose IMKInputSession_Modern\n"
     ]
    }
   ],
   "source": [
    "# BATCH_SIZE is the number of transitions sampled from the replay buffer\n",
    "# GAMMA is the discount factor as mentioned in the previous section\n",
    "# EPS_START is the starting value of epsilon\n",
    "# EPS_END is the final value of epsilon\n",
    "# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay\n",
    "# TAU is the update rate of the target network\n",
    "# LR is the learning rate of the AdamW optimizer\n",
    "BATCH_SIZE = 128\n",
    "GAMMA = 0.99\n",
    "EPS_START = 0.9\n",
    "EPS_END = 0.05\n",
    "EPS_DECAY = 1000\n",
    "TAU = 0.005\n",
    "LR = 1e-4\n",
    "\n",
    "# Get number of actions from gym action space\n",
    "n_actions = env.action_space.shape[0]\n",
    "# Get the number of state observations\n",
    "state, info = env.reset()\n",
    "if(type(env.observation_space) == gym.spaces.discrete.Discrete):\n",
    "    n_observations = env.observation_space.n\n",
    "else:\n",
    "    n_observations = len(state)\n",
    "    \n",
    "\n",
    "policy_net = DQN(n_observations, n_actions).to(device)\n",
    "target_net = DQN(n_observations, n_actions).to(device)\n",
    "target_net.load_state_dict(policy_net.state_dict())\n",
    "\n",
    "optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)\n",
    "memory = ReplayMemory(10000)\n",
    "\n",
    "\n",
    "steps_done = 0\n",
    "\n",
    "episode_durations = []\n",
    "\n",
    "\n",
    "def plot_durations(show_result=False):\n",
    "    plt.figure(1)\n",
    "    durations_t = torch.tensor(episode_durations, dtype=torch.float)\n",
    "    if show_result:\n",
    "        plt.title('Result')\n",
    "    else:\n",
    "        plt.clf()\n",
    "        plt.title('Training...')\n",
    "    plt.xlabel('Episode')\n",
    "    plt.ylabel('Duration')\n",
    "    plt.plot(durations_t.numpy())\n",
    "    # Take 100 episode averages and plot them too\n",
    "    if len(durations_t) >= 100:\n",
    "        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)\n",
    "        means = torch.cat((torch.zeros(99), means))\n",
    "        plt.plot(means.numpy())\n",
    "\n",
    "    plt.pause(0.001)  # pause a bit so that plots are updated\n",
    "    if is_ipython:\n",
    "        if not show_result:\n",
    "            display.display(plt.gcf())\n",
    "            display.clear_output(wait=True)\n",
    "        else:\n",
    "            display.display(plt.gcf())\n",
    "\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def select_action(state):\n",
    "    global steps_done\n",
    "    sample = random.random()\n",
    "    eps_threshold = EPS_END + (EPS_START - EPS_END) * \\\n",
    "        math.exp(-1. * steps_done / EPS_DECAY)\n",
    "    steps_done += 1\n",
    "    if sample > eps_threshold:\n",
    "        with torch.no_grad():\n",
    "            return policy_net(state).numpy()[0]\n",
    "    else:\n",
    "        return env.action_space.sample()\n",
    "\n",
    "def optimize_model():\n",
    "    if len(memory) < BATCH_SIZE:\n",
    "        return\n",
    "    transitions = memory.sample(BATCH_SIZE)\n",
    "    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for\n",
    "    # detailed explanation). This converts batch-array of Transitions\n",
    "    # to Transition of batch-arrays.\n",
    "    batch = Transition(*zip(*transitions))\n",
    "\n",
    "    # Compute a mask of non-final states and concatenate the batch elements\n",
    "    # (a final state would've been the one after which simulation ended)\n",
    "    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,\n",
    "                                          batch.next_state)), device=device, dtype=torch.bool)\n",
    "    non_final_next_states = torch.cat([s for s in batch.next_state\n",
    "                                                if s is not None])\n",
    "    state_batch = torch.cat(batch.state)\n",
    "    # print(batch)\n",
    "    # action_batch = torch.cat(torch.from_numpy(batch.action),)\n",
    "    # reward_batch = torch.cat(batch.reward)\n",
    "    reward_batch = torch.tensor(batch.reward)\n",
    "\n",
    "    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the\n",
    "    # columns of actions taken. These are the actions which would've been taken\n",
    "    # for each batch state according to policy_net)\n",
    "    # a =action_batch.numpy()\n",
    "    # b = policy_net(state_batch).detach().numpy()\n",
    "    # state_action_values = policy_net(state_batch).gather(1,action_batch)\n",
    "    state_action_values = policy_net(state_batch)\n",
    "\n",
    "    # Compute V(s_{t+1}) for all next states.\n",
    "    # Expected values of actions for non_final_next_states are computed based\n",
    "    # on the \"older\" target_net; selecting their best reward with max(1)[0].\n",
    "    # This is merged based on the mask, such that we'll have either the expected\n",
    "    # state value or 0 in case the state was final.\n",
    "    next_state_values = torch.zeros(BATCH_SIZE, device=device)\n",
    "    with torch.no_grad():\n",
    "        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]\n",
    "    # Compute the expected Q values\n",
    "    expected_state_action_values = (next_state_values * GAMMA) + reward_batch\n",
    "\n",
    "    # Compute Huber loss\n",
    "    criterion = nn.SmoothL1Loss()\n",
    "    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))\n",
    "\n",
    "    # Optimize the model\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    # In-place gradient clipping\n",
    "    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)\n",
    "    optimizer.step()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    num_episodes = 600\n",
    "else:\n",
    "    num_episodes = 600\n",
    "\n",
    "for i_episode in range(num_episodes):\n",
    "    # Initialize the environment and get it's state\n",
    "    state, info = env.reset()\n",
    "    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)\n",
    "    for t in count():\n",
    "        action = select_action(state)\n",
    "        print(action)\n",
    "        observation, reward, terminated, truncated, _ = env.step(action)\n",
    "        # print(info['reward_forward'])\n",
    "        # print(env.step(action))\n",
    "        # print(info.keys())\n",
    "        reward = torch.tensor(reward,device=device)\n",
    "        # reward = torch.tensor([reward], device=device)\n",
    "        # print(reward)\n",
    "        done = terminated or truncated\n",
    "\n",
    "        if terminated:\n",
    "            next_state = None\n",
    "        else:\n",
    "            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)\n",
    "\n",
    "        # Store the transition in memory\n",
    "        memory.push(state, action, next_state, reward)\n",
    "\n",
    "        # Move to the next state\n",
    "        state = next_state\n",
    "\n",
    "        # Perform one step of the optimization (on the policy network)\n",
    "        optimize_model()\n",
    "\n",
    "        # Soft update of the target network's weights\n",
    "        # θ′ ← τ θ + (1 −τ )θ′\n",
    "        target_net_state_dict = target_net.state_dict()\n",
    "        policy_net_state_dict = policy_net.state_dict()\n",
    "        for key in policy_net_state_dict:\n",
    "            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)\n",
    "        target_net.load_state_dict(target_net_state_dict)\n",
    "\n",
    "        if done:\n",
    "            episode_durations.append(reward.item())\n",
    "            plot_durations()\n",
    "            break\n",
    "\n",
    "print('Complete')\n",
    "plot_durations(show_result=True)\n",
    "plt.ioff()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mac-rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
